Fix static generation when compiling! #28937
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I'm not sure adding a new argument is the best approach. The following works on `main`:

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-random-LlamaForCausalLM", attn_implementation="eager"
)
tokenizer = LlamaTokenizer.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")

# random input ids
inputs = tokenizer("Hey there", return_tensors="pt", return_attention_mask=True)
position_ids = inputs.attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(inputs.attention_mask == 0, 1)

with torch.no_grad():
    logits = model.forward(**inputs, position_ids=position_ids).logits
```

If we run the same code on this PR, we get the following error:
Full traceback:
This is because in `src/transformers/models/llama/modeling_llama.py` (lines 352 to 353 at 56768a0), instead of indexing with `[:, :, cache_position, : key_states.shape[-2]]`, we index with `[:, :, None, : key_states.shape[-2]]` (since `cache_position` defaults to `None`). So instead of slicing, we insert an extra dimension! This gives the size mismatch when we add the attention mask to the weights. The user needs to specify `cache_position` as an argument to the forward call in order for this to work.
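To see the difference, here is a minimal standalone snippet (my own illustration, not the modeling code itself) contrasting indexing with a position tensor against indexing with `None`:

```python
import torch

# Toy "causal mask" shaped like [batch, heads, max_seq_len, max_seq_len]
mask = torch.zeros(1, 1, 8, 8)
cache_position = torch.arange(3)  # positions of the 3 tokens in the current forward pass
key_len = 3                       # stand-in for key_states.shape[-2]

sliced = mask[:, :, cache_position, :key_len]
print(sliced.shape)  # torch.Size([1, 1, 3, 3]) -> the rows we wanted

broken = mask[:, :, None, :key_len]
print(broken.shape)  # torch.Size([1, 1, 1, 8, 3]) -> `None` inserts a new axis instead of slicing
```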
Overall, I think we should avoid adding extra arguments that require code changes from the user, especially to the top-level modules which are already heavily used. What about a design more like Flax, where we keep track of the …
We can make it BC! This PR is not ready yet, but `generate` should check the past key value class and, if the signature can take `cache_position`, pass it. Something like that. I'll work on making it BC! :)
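As a rough sketch of that BC idea (a hypothetical helper, not the PR's actual implementation), `generate` could inspect the forward signature and only pass `cache_position` when the model accepts it:

```python
import inspect
import torch

def maybe_add_cache_position(model, model_kwargs, past_seen_tokens, seq_length):
    # Only pass `cache_position` to models whose forward signature accepts it.
    if "cache_position" in inspect.signature(model.forward).parameters:
        if model_kwargs.get("cache_position") is None:
            model_kwargs["cache_position"] = torch.arange(
                past_seen_tokens, past_seen_tokens + seq_length
            )
    return model_kwargs
```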
```python
past_seen_tokens = 0
if use_cache and not isinstance(past_key_values, Cache):
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_seen_tokens = past_key_values.get_usable_length(inputs_embeds.shape[1])  # kept for BC (cache positions)

if cache_position is None:
    cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
```
Has to be kept for BC
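For intuition, a small standalone example (not from the PR) of the values this default produces on a prefill step versus a single-token decoding step:

```python
import torch

# Prefill: nothing cached yet, 5 prompt tokens in this forward pass
past_seen_tokens, seq_len = 0, 5
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_len))  # tensor([0, 1, 2, 3, 4])

# Decoding: 5 tokens already cached, 1 new token in this forward pass
past_seen_tokens, seq_len = 5, 1
print(torch.arange(past_seen_tokens, past_seen_tokens + seq_len))  # tensor([5])
```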
```python
if attention_mask is None:
    return None

is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and torch.all(attention_mask == 1):
    return None
if is_tracing and seq_length == 1:
    return None
```
All of this made generation fail; will deal with it later.
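For context, a minimal sketch (my own example, not the PR code) of why the data-dependent `torch.all(attention_mask == 1)` shortcut has to be guarded by `is_tracing`: under tracing the condition would be evaluated once on the example input and baked into the graph, so the guard keeps the mask whenever we are tracing:

```python
import torch

def maybe_skip_mask(attention_mask):
    # Skip the mask only in eager mode; while tracing, keep it so the branch
    # taken on the example input is not frozen into the exported graph.
    if not torch.jit.is_tracing() and torch.all(attention_mask == 1):
        return None
    return attention_mask

print(maybe_skip_mask(torch.ones(1, 4)))              # None (no padding, eager mode)
print(maybe_skip_mask(torch.tensor([[1, 1, 0, 0]])))  # tensor([[1, 1, 0, 0]])
```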
cc @fxmarty, I am warning you in advance 🥶 you might have to do something similar to the `prepared_4d_sdpa`, but this is a lot simpler, so for the better.
```python
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
    cache_position = torch.arange(past_length, past_length + input_ids.shape[1])
```
Kept for BC as well; `generate` should handle cache positions IMO.
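A rough sketch of what the TODO above suggests (hypothetical code, not how `generate` currently works): `generate` would own `cache_position` itself, building it once at prefill and advancing it by one each decoding step:

```python
import torch

def greedy_decode_sketch(model, input_ids, max_new_tokens):
    # Prefill: positions 0 .. prompt_len - 1
    cache_position = torch.arange(input_ids.shape[1])
    outputs = model(input_ids, cache_position=cache_position, use_cache=True)
    for _ in range(max_new_tokens):
        next_token = outputs.logits[:, -1:].argmax(-1)
        # Decoding: a single position, advanced by one each step
        cache_position = cache_position[-1:] + 1
        outputs = model(
            next_token,
            cache_position=cache_position,
            past_key_values=outputs.past_key_values,
            use_cache=True,
        )
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids
```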
Pre-approving, as the overall PR shape looks good to me 👍
(btw, this PR is blocking further work on generate, as llama + generate + dynamic cache is not correct at the moment and I want to standardize the interface of the different cache classes to match the static cache)
Thanks, merging asap
```diff
-bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
-non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+keys_to_ignore = ["cache_position", "encoder_outputs"]
+non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
```
beam search will split the cache positions otherwise
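A small illustration (my own, simplified from what the splitting helper does) of the problem: batched inputs are split along the batch dimension, but `cache_position` indexes sequence positions rather than batch entries, so it has to be replicated instead of split:

```python
import torch

model_input = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),  # batch of 2 (e.g. 2 beams)
    "cache_position": torch.arange(3),                   # positions 0..2, shared by every beam
}

# Naively splitting every tensor along dim 0:
naive = [{k: v[i] for k, v in model_input.items()} for i in range(2)]
print(naive[0]["cache_position"])  # tensor(0) -> each beam loses the full position range

# Ignoring `cache_position` in the split and replicating it keeps the semantics:
split = [
    {"input_ids": model_input["input_ids"][i : i + 1], "cache_position": model_input["cache_position"]}
    for i in range(2)
]
print(split[0]["cache_position"])  # tensor([0, 1, 2]) for every beam
```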
Thanks for the huge work! I left some minor comments that should be addressed before merging IMO, otherwise we might introduce a breaking change for users that use our public classes without explicit positional arguments.
Example of a breaking behaviour that I introduced while working on FA2: #25598 (comment), so we should be careful when adding new args in our modules.
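A toy example (deliberately unrelated to the real signatures) of the failure mode being flagged here: inserting a new parameter in the middle of a signature silently breaks callers that pass arguments positionally:

```python
def forward_v1(input_ids, attention_mask, position_ids):
    return input_ids, attention_mask, position_ids

forward_v1([1, 2], [1, 1], [0, 1])  # caller relying on positional arguments: works

def forward_v2(input_ids, attention_mask, cache_position, position_ids=None):
    # New argument inserted before `position_ids`
    return input_ids, attention_mask, cache_position, position_ids

forward_v2([1, 2], [1, 1], [0, 1])  # [0, 1] now silently binds to cache_position, not position_ids
# -> new arguments should go at the end of the signature instead.
```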
Co-authored-by: Younes Belkada <[email protected]>
…hub.com:huggingface/transformers into fix-static-kv-cache
Thank you very much!
* wow I was scared!
* fix everything
* nits
* make it BC?
* add todo
* nits
* is_tracing should still be used to pass tracing tests
* nits
* some nits to make sure genration works with static cache uncompiled
* fix sdpa
* fix FA2 for both static and dynamic in a better way?
* style
* fix-copies
* fix fix copies
* fix sequential beam searcg
* style
* use `keys_to_ignore`
* nit
* correct dtype inference when init
* :( the fix for FA2 is still not optimal to investigate!
* styling
* nits
* nit
* this might work better
* add comment
* Update src/transformers/models/llama/modeling_llama.py
* "position_ids" -> "cache_position"
* style
* nit
* Remove changes that should no be propagatted just yet
* Apply suggestions from code review
* Styling
* make sure we raise an errir for static cache with FA2 enabled
* move to the bottom of the signature
* style
* Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <[email protected]>
* Update src/transformers/models/llama/modeling_llama.py
* nit in the name

---------

Co-authored-by: Younes Belkada <[email protected]>
Hey @ArthurZucker, I discovered that this change actually breaks TPU... Now, TPU training with FSDPv2 produces NaN loss. I haven't looked into your PR so I'm not sure why; I just bisected it down to this change.
Mmm, this might be a RoPE issue? #29109 might also play a role.
Hi @ArthurZucker, I ran your benchmark script with both transformers 4.38.0 and 4.38.2 but got an error:
It is probably out of date! I'll update it.
We'll actually push a full benchmark in …
What does this PR do?
Fixes the static cache generation. Comes with #27931
Thanks @OlivierDehaene for the insight.
Benchmark: https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03
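For reference, a minimal usage sketch of the path this PR fixes (following the static-cache API introduced alongside #27931; the checkpoint and compile flags here are illustrative, not taken from the gist):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # any Llama checkpoint works for this example
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Static KV cache keeps tensor shapes fixed, which lets torch.compile avoid recompilation.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```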
`generate` … because the first forward will be fully causal.

FA2 potential fix if compiled worked:

but I have slowdowns:

Slicing

vs no Slicing